{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "fe6fd0ab0e1ad844",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "# Transfer Learning with Sentence Transformers and Scikit-Learn"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8dcde44d942793ff",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Introduction\n",
    "\n",
    "In this notebook, we will explore the process of transfer learning using SuperDuperDB. We will demonstrate how to connect to a MongoDB datastore, load a dataset, create a SuperDuperDB model based on Sentence Transformers, train a downstream model using Scikit-Learn, and apply the trained model to the database. Transfer learning is a powerful technique that can be used in various applications, such as vector search and downstream learning tasks."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1809feca8a8dca5a",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "## Prerequisites\n",
    "\n",
    "Before diving into the implementation, ensure that you have the necessary libraries installed by running the following commands:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94f3219ad932a327",
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install superduperdb\n",
    "!pip install ipython numpy datasets sentence-transformers"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6bc151f6",
   "metadata": {},
   "source": [
    "## Connect to datastore "
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5379007991707d17",
   "metadata": {
    "collapsed": false,
    "jupyter": {
     "outputs_hidden": false
    }
   },
   "source": [
    "First, we need to establish a connection to a MongoDB datastore via SuperDuperDB. You can configure the `MongoDB_URI` based on your specific setup. \n",
    "Here are some examples of MongoDB URIs:\n",
    "\n",
    "* For testing (default connection): `mongomock://test`\n",
    "* Local MongoDB instance: `mongodb://localhost:27017`\n",
    "* MongoDB with authentication: `mongodb://superduper:superduper@mongodb:27017/documents`\n",
    "* MongoDB Atlas: `mongodb+srv://<username>:<password>@<atlas_cluster>/<database>`"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "44f8ef76",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32m 2023-Dec-13 11:12:52.51\u001b[0m| \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.base.build\u001b[0m:\u001b[36m50  \u001b[0m | \u001b[34m\u001b[1mParsing data connection URI:mongomock://test\u001b[0m\n",
      "\u001b[32m 2023-Dec-13 11:12:52.52\u001b[0m| \u001b[1mINFO    \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.base.build\u001b[0m:\u001b[36m137 \u001b[0m | \u001b[1mData Client is ready. mongomock.MongoClient('localhost', 27017)\u001b[0m\n",
      "\u001b[32m 2023-Dec-13 11:12:52.53\u001b[0m| \u001b[1mINFO    \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.base.datalayer\u001b[0m:\u001b[36m79  \u001b[0m | \u001b[1mBuilding Data Layer\u001b[0m\n"
     ]
    }
   ],
   "source": [
    "from superduperdb import superduper\n",
    "from superduperdb.backends.mongodb import Collection\n",
    "import os\n",
    "\n",
    "mongodb_uri = os.getenv(\"MONGODB_URI\", \"mongomock://test\")\n",
    "\n",
    "# SuperDuperDB, now handles your MongoDB database\n",
    "# It just super dupers your database\n",
    "db = superduper(mongodb_uri, artifact_store='filesystem://./data/')\n",
    "\n",
    "# Reference a collection called transfer\n",
    "collection = Collection('transfer')"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97fede97",
   "metadata": {},
   "source": [
    "## Load Dataset\n",
    "\n",
    "Transfer learning can be applied to any data that can be processed with SuperDuperDB models.\n",
    "For our example, we will use a labeled textual dataset with sentiment analysis.  We'll load a subset of the IMDb dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8bb65106",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32m 2023-Dec-13 11:12:55.72\u001b[0m| \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.base.datalayer\u001b[0m:\u001b[36m716 \u001b[0m | \u001b[34m\u001b[1mBuilding task workflow graph. Query:<superduperdb.backends.mongodb.query.MongoCompoundSelect[\n",
      "    \u001b[92m\u001b[1mtransfer.find({}, {})\u001b[0m}\n",
      "] object at 0x15630f490>\u001b[0m\n",
      "\u001b[32m 2023-Dec-13 11:12:55.72\u001b[0m| \u001b[1mINFO    \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.backends.local.compute\u001b[0m:\u001b[36m32  \u001b[0m | \u001b[1mSubmitting job. function:<function callable_job at 0x110919120>\u001b[0m\n",
      "\u001b[32m 2023-Dec-13 11:12:55.72\u001b[0m| \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.misc.download\u001b[0m:\u001b[36m337 \u001b[0m | \u001b[34m\u001b[1m{'cls': 'MongoCompoundSelect', 'dict': {'table_or_collection': {'cls': 'Collection', 'dict': {'identifier': 'transfer'}, 'module': 'superduperdb.backends.mongodb.query'}, 'pre_like': None, 'post_like': None, 'query_linker': {'cls': 'MongoQueryLinker', 'dict': {'table_or_collection': {'cls': 'Collection', 'dict': {'identifier': 'transfer'}, 'module': 'superduperdb.backends.mongodb.query'}, 'members': [{'cls': 'Find', 'dict': {'name': 'find', 'type': <QueryType.QUERY: 'query'>, 'args': [{}, {}], 'kwargs': {}, 'output_fields': None}, 'module': 'superduperdb.backends.mongodb.query'}]}, 'module': 'superduperdb.backends.mongodb.query'}}, 'module': 'superduperdb.backends.mongodb.query'}\u001b[0m\n",
      "\u001b[32m 2023-Dec-13 11:12:55.72\u001b[0m| \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.misc.download\u001b[0m:\u001b[36m338 \u001b[0m | \u001b[34m\u001b[1m[ObjectId('657983a788bef09ab851d97a'), ObjectId('657983a788bef09ab851d97b'), ObjectId('657983a788bef09ab851d97c'), ObjectId('657983a788bef09ab851d97d'), ObjectId('657983a788bef09ab851d97e'), ObjectId('657983a788bef09ab851d97f'), ObjectId('657983a788bef09ab851d980'), ObjectId('657983a788bef09ab851d981'), ObjectId('657983a788bef09ab851d982'), ObjectId('657983a788bef09ab851d983'), ObjectId('657983a788bef09ab851d984'), ObjectId('657983a788bef09ab851d985'), ObjectId('657983a788bef09ab851d986'), ObjectId('657983a788bef09ab851d987'), ObjectId('657983a788bef09ab851d988'), ObjectId('657983a788bef09ab851d989'), ObjectId('657983a788bef09ab851d98a'), ObjectId('657983a788bef09ab851d98b'), ObjectId('657983a788bef09ab851d98c'), ObjectId('657983a788bef09ab851d98d'), ObjectId('657983a788bef09ab851d98e'), ObjectId('657983a788bef09ab851d98f'), ObjectId('657983a788bef09ab851d990'), ObjectId('657983a788bef09ab851d991'), ObjectId('657983a788bef09ab851d992'), ObjectId('657983a788bef09ab851d993'), ObjectId('657983a788bef09ab851d994'), ObjectId('657983a788bef09ab851d995'), ObjectId('657983a788bef09ab851d996'), ObjectId('657983a788bef09ab851d997'), ObjectId('657983a788bef09ab851d998'), ObjectId('657983a788bef09ab851d999'), ObjectId('657983a788bef09ab851d99a'), ObjectId('657983a788bef09ab851d99b'), ObjectId('657983a788bef09ab851d99c'), ObjectId('657983a788bef09ab851d99d'), ObjectId('657983a788bef09ab851d99e'), ObjectId('657983a788bef09ab851d99f'), ObjectId('657983a788bef09ab851d9a0'), ObjectId('657983a788bef09ab851d9a1'), ObjectId('657983a788bef09ab851d9a2'), ObjectId('657983a788bef09ab851d9a3'), ObjectId('657983a788bef09ab851d9a4'), ObjectId('657983a788bef09ab851d9a5'), ObjectId('657983a788bef09ab851d9a6'), ObjectId('657983a788bef09ab851d9a7'), ObjectId('657983a788bef09ab851d9a8'), ObjectId('657983a788bef09ab851d9a9'), ObjectId('657983a788bef09ab851d9aa'), ObjectId('657983a788bef09ab851d9ab'), ObjectId('657983a788bef09ab851d9ac'), ObjectId('657983a788bef09ab851d9ad'), ObjectId('657983a788bef09ab851d9ae'), ObjectId('657983a788bef09ab851d9af'), ObjectId('657983a788bef09ab851d9b0'), ObjectId('657983a788bef09ab851d9b1'), ObjectId('657983a788bef09ab851d9b2'), ObjectId('657983a788bef09ab851d9b3'), ObjectId('657983a788bef09ab851d9b4'), ObjectId('657983a788bef09ab851d9b5'), ObjectId('657983a788bef09ab851d9b6'), ObjectId('657983a788bef09ab851d9b7'), ObjectId('657983a788bef09ab851d9b8'), ObjectId('657983a788bef09ab851d9b9'), ObjectId('657983a788bef09ab851d9ba'), ObjectId('657983a788bef09ab851d9bb'), ObjectId('657983a788bef09ab851d9bc'), ObjectId('657983a788bef09ab851d9bd'), ObjectId('657983a788bef09ab851d9be'), ObjectId('657983a788bef09ab851d9bf'), ObjectId('657983a788bef09ab851d9c0'), ObjectId('657983a788bef09ab851d9c1'), ObjectId('657983a788bef09ab851d9c2'), ObjectId('657983a788bef09ab851d9c3'), ObjectId('657983a788bef09ab851d9c4'), ObjectId('657983a788bef09ab851d9c5'), ObjectId('657983a788bef09ab851d9c6'), ObjectId('657983a788bef09ab851d9c7'), ObjectId('657983a788bef09ab851d9c8'), ObjectId('657983a788bef09ab851d9c9'), ObjectId('657983a788bef09ab851d9ca'), ObjectId('657983a788bef09ab851d9cb'), ObjectId('657983a788bef09ab851d9cc'), ObjectId('657983a788bef09ab851d9cd'), ObjectId('657983a788bef09ab851d9ce'), ObjectId('657983a788bef09ab851d9cf'), ObjectId('657983a788bef09ab851d9d0'), ObjectId('657983a788bef09ab851d9d1'), ObjectId('657983a788bef09ab851d9d2'), ObjectId('657983a788bef09ab851d9d3'), ObjectId('657983a788bef09ab851d9d4'), ObjectId('657983a788bef09ab851d9d5'), ObjectId('657983a788bef09ab851d9d6'), ObjectId('657983a788bef09ab851d9d7'), ObjectId('657983a788bef09ab851d9d8'), ObjectId('657983a788bef09ab851d9d9'), ObjectId('657983a788bef09ab851d9da'), ObjectId('657983a788bef09ab851d9db'), ObjectId('657983a788bef09ab851d9dc'), ObjectId('657983a788bef09ab851d9dd')]\u001b[0m\n",
      "\u001b[32m 2023-Dec-13 11:12:55.73\u001b[0m| \u001b[32m\u001b[1mSUCCESS \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.backends.local.compute\u001b[0m:\u001b[36m38  \u001b[0m | \u001b[32m\u001b[1mJob submitted.  function:<function callable_job at 0x110919120> future:71f1e105-0b4b-46fa-9d8d-66d9969535cc\u001b[0m\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([ObjectId('657983a788bef09ab851d97a'),\n",
       "  ObjectId('657983a788bef09ab851d97b'),\n",
       "  ObjectId('657983a788bef09ab851d97c'),\n",
       "  ObjectId('657983a788bef09ab851d97d'),\n",
       "  ObjectId('657983a788bef09ab851d97e'),\n",
       "  ObjectId('657983a788bef09ab851d97f'),\n",
       "  ObjectId('657983a788bef09ab851d980'),\n",
       "  ObjectId('657983a788bef09ab851d981'),\n",
       "  ObjectId('657983a788bef09ab851d982'),\n",
       "  ObjectId('657983a788bef09ab851d983'),\n",
       "  ObjectId('657983a788bef09ab851d984'),\n",
       "  ObjectId('657983a788bef09ab851d985'),\n",
       "  ObjectId('657983a788bef09ab851d986'),\n",
       "  ObjectId('657983a788bef09ab851d987'),\n",
       "  ObjectId('657983a788bef09ab851d988'),\n",
       "  ObjectId('657983a788bef09ab851d989'),\n",
       "  ObjectId('657983a788bef09ab851d98a'),\n",
       "  ObjectId('657983a788bef09ab851d98b'),\n",
       "  ObjectId('657983a788bef09ab851d98c'),\n",
       "  ObjectId('657983a788bef09ab851d98d'),\n",
       "  ObjectId('657983a788bef09ab851d98e'),\n",
       "  ObjectId('657983a788bef09ab851d98f'),\n",
       "  ObjectId('657983a788bef09ab851d990'),\n",
       "  ObjectId('657983a788bef09ab851d991'),\n",
       "  ObjectId('657983a788bef09ab851d992'),\n",
       "  ObjectId('657983a788bef09ab851d993'),\n",
       "  ObjectId('657983a788bef09ab851d994'),\n",
       "  ObjectId('657983a788bef09ab851d995'),\n",
       "  ObjectId('657983a788bef09ab851d996'),\n",
       "  ObjectId('657983a788bef09ab851d997'),\n",
       "  ObjectId('657983a788bef09ab851d998'),\n",
       "  ObjectId('657983a788bef09ab851d999'),\n",
       "  ObjectId('657983a788bef09ab851d99a'),\n",
       "  ObjectId('657983a788bef09ab851d99b'),\n",
       "  ObjectId('657983a788bef09ab851d99c'),\n",
       "  ObjectId('657983a788bef09ab851d99d'),\n",
       "  ObjectId('657983a788bef09ab851d99e'),\n",
       "  ObjectId('657983a788bef09ab851d99f'),\n",
       "  ObjectId('657983a788bef09ab851d9a0'),\n",
       "  ObjectId('657983a788bef09ab851d9a1'),\n",
       "  ObjectId('657983a788bef09ab851d9a2'),\n",
       "  ObjectId('657983a788bef09ab851d9a3'),\n",
       "  ObjectId('657983a788bef09ab851d9a4'),\n",
       "  ObjectId('657983a788bef09ab851d9a5'),\n",
       "  ObjectId('657983a788bef09ab851d9a6'),\n",
       "  ObjectId('657983a788bef09ab851d9a7'),\n",
       "  ObjectId('657983a788bef09ab851d9a8'),\n",
       "  ObjectId('657983a788bef09ab851d9a9'),\n",
       "  ObjectId('657983a788bef09ab851d9aa'),\n",
       "  ObjectId('657983a788bef09ab851d9ab'),\n",
       "  ObjectId('657983a788bef09ab851d9ac'),\n",
       "  ObjectId('657983a788bef09ab851d9ad'),\n",
       "  ObjectId('657983a788bef09ab851d9ae'),\n",
       "  ObjectId('657983a788bef09ab851d9af'),\n",
       "  ObjectId('657983a788bef09ab851d9b0'),\n",
       "  ObjectId('657983a788bef09ab851d9b1'),\n",
       "  ObjectId('657983a788bef09ab851d9b2'),\n",
       "  ObjectId('657983a788bef09ab851d9b3'),\n",
       "  ObjectId('657983a788bef09ab851d9b4'),\n",
       "  ObjectId('657983a788bef09ab851d9b5'),\n",
       "  ObjectId('657983a788bef09ab851d9b6'),\n",
       "  ObjectId('657983a788bef09ab851d9b7'),\n",
       "  ObjectId('657983a788bef09ab851d9b8'),\n",
       "  ObjectId('657983a788bef09ab851d9b9'),\n",
       "  ObjectId('657983a788bef09ab851d9ba'),\n",
       "  ObjectId('657983a788bef09ab851d9bb'),\n",
       "  ObjectId('657983a788bef09ab851d9bc'),\n",
       "  ObjectId('657983a788bef09ab851d9bd'),\n",
       "  ObjectId('657983a788bef09ab851d9be'),\n",
       "  ObjectId('657983a788bef09ab851d9bf'),\n",
       "  ObjectId('657983a788bef09ab851d9c0'),\n",
       "  ObjectId('657983a788bef09ab851d9c1'),\n",
       "  ObjectId('657983a788bef09ab851d9c2'),\n",
       "  ObjectId('657983a788bef09ab851d9c3'),\n",
       "  ObjectId('657983a788bef09ab851d9c4'),\n",
       "  ObjectId('657983a788bef09ab851d9c5'),\n",
       "  ObjectId('657983a788bef09ab851d9c6'),\n",
       "  ObjectId('657983a788bef09ab851d9c7'),\n",
       "  ObjectId('657983a788bef09ab851d9c8'),\n",
       "  ObjectId('657983a788bef09ab851d9c9'),\n",
       "  ObjectId('657983a788bef09ab851d9ca'),\n",
       "  ObjectId('657983a788bef09ab851d9cb'),\n",
       "  ObjectId('657983a788bef09ab851d9cc'),\n",
       "  ObjectId('657983a788bef09ab851d9cd'),\n",
       "  ObjectId('657983a788bef09ab851d9ce'),\n",
       "  ObjectId('657983a788bef09ab851d9cf'),\n",
       "  ObjectId('657983a788bef09ab851d9d0'),\n",
       "  ObjectId('657983a788bef09ab851d9d1'),\n",
       "  ObjectId('657983a788bef09ab851d9d2'),\n",
       "  ObjectId('657983a788bef09ab851d9d3'),\n",
       "  ObjectId('657983a788bef09ab851d9d4'),\n",
       "  ObjectId('657983a788bef09ab851d9d5'),\n",
       "  ObjectId('657983a788bef09ab851d9d6'),\n",
       "  ObjectId('657983a788bef09ab851d9d7'),\n",
       "  ObjectId('657983a788bef09ab851d9d8'),\n",
       "  ObjectId('657983a788bef09ab851d9d9'),\n",
       "  ObjectId('657983a788bef09ab851d9da'),\n",
       "  ObjectId('657983a788bef09ab851d9db'),\n",
       "  ObjectId('657983a788bef09ab851d9dc'),\n",
       "  ObjectId('657983a788bef09ab851d9dd')],\n",
       " TaskWorkflow(database=<superduperdb.base.datalayer.Datalayer object at 0x154df6410>, G=<networkx.classes.digraph.DiGraph object at 0x1562ed610>))"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import numpy\n",
    "from datasets import load_dataset\n",
    "from superduperdb import Document as D\n",
    "\n",
    "# Load IMDb dataset\n",
    "data = load_dataset(\"imdb\")\n",
    "\n",
    "# Set the number of data points for training (adjust as needed)\n",
    "N_DATAPOINTS = 100\n",
    "\n",
    "# Prepare training data\n",
    "train_data = [\n",
    "    D({'_fold': 'train', **data['train'][int(i)]})\n",
    "    for i in numpy.random.permutation(len(data['train']))\n",
    "][:N_DATAPOINTS]\n",
    "\n",
    "# Prepare validation data\n",
    "valid_data = [\n",
    "    D({'_fold': 'valid', **data['test'][int(i)]})\n",
    "    for i in numpy.random.permutation(len(data['test']))\n",
    "][:N_DATAPOINTS // 10]\n",
    "\n",
    "# Insert training data into the 'collection' SuperDuperDB collection\n",
    "db.execute(collection.insert_many(train_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00a92214",
   "metadata": {},
   "source": [
    "## Run Model\n",
    "\n",
    "We'll create a SuperDuperDB model based on the `sentence_transformers` library. This demonstrates that you don't necessarily need a native SuperDuperDB integration with a model library to leverage its power. We configure the `Model wrapper` to work with the `SentenceTransformer class`. After configuration, we can link the model to a collection and daemonize the model with the `listen=True` keyword."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "35c9276c-9418-4d24-932d-c8a36fb60a42",
   "metadata": {},
   "outputs": [],
   "source": [
    "import sentence_transformers\n",
    "from sklearn.svm import SVC\n",
    "\n",
    "from superduperdb import Model\n",
    "from superduperdb.ext.numpy import array\n",
    "from superduperdb.ext.sklearn import Estimator\n",
    "from superduperdb import superduper\n",
    "from superduperdb.components.stack import Stack\n",
    "\n",
    "\n",
    "# Create a SuperDuperDB Model using Sentence Transformers\n",
    "m1 = Model(\n",
    "    identifier='all-MiniLM-L6-v2',\n",
    "    object=sentence_transformers.SentenceTransformer('all-MiniLM-L6-v2'),\n",
    "    encoder=array('float32', shape=(384,)),\n",
    "    predict_method='encode',\n",
    "    batch_predict=True,\n",
    "    predict_X='text',\n",
    "    predict_select=collection.find(),\n",
    "    predict_kwargs={'show_progress_bar': True},\n",
    ")\n",
    "\n",
    "\n",
    "# Create a SuperDuperDB model with an SVC classifier\n",
    "m2 = Estimator(\n",
    "    'svc',\n",
    "    object=SVC(gamma='scale', class_weight='balanced', C=100, verbose=True),\n",
    "    postprocess=lambda x: int(x),\n",
    "    train_X='_outputs.text.all-MiniLM-L6-v2.0',\n",
    "    train_y='label',\n",
    "    train_select=collection.find(),\n",
    "    predict_X='_outputs.text.all-MiniLM-L6-v2.0',\n",
    "    predict_select=collection.find({'_fold': 'valid'})\n",
    ")\n",
    "\n",
    "stack = Stack('my-stack', components=[m1, m2])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "a8ac17d7-5f1d-4e31-9c49-f3c63c40104c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32m 2023-Dec-13 11:12:58.50\u001b[0m| \u001b[1mINFO    \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.components.model\u001b[0m:\u001b[36m231 \u001b[0m | \u001b[1mAdding model all-MiniLM-L6-v2 to db\u001b[0m\n",
      "\u001b[32m 2023-Dec-13 11:12:58.50\u001b[0m| \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.base.datalayer\u001b[0m:\u001b[36m873 \u001b[0m | \u001b[34m\u001b[1mmodel/all-MiniLM-L6-v2/0 already exists - doing nothing\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100it [00:00, 49531.22it/s]\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "409dbac717864e95a3a4f0d5c24d4bd5",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Batches:   0%|          | 0/4 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\u001b[32m 2023-Dec-13 11:13:13.25\u001b[0m| \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.base.datalayer\u001b[0m:\u001b[36m873 \u001b[0m | \u001b[34m\u001b[1mmodel/svc/0 already exists - doing nothing\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 92/92 [00:00<00:00, 58263.02it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[LibSVM].*\n",
      "optimization finished, #iter = 178\n",
      "obj = -51.025101, rho = 0.245493\n",
      "nSV = 89, nBSV = 0\n",
      "Total nSV = 89\n",
      "\u001b[32m 2023-Dec-13 11:13:13.26\u001b[0m| \u001b[1mINFO    \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.components.model\u001b[0m:\u001b[36m231 \u001b[0m | \u001b[1mAdding model svc to db\u001b[0m\n",
      "\u001b[32m 2023-Dec-13 11:13:13.26\u001b[0m| \u001b[34m\u001b[1mDEBUG   \u001b[0m | \u001b[36mDuncans-MacBook-Pro.local\u001b[0m| \u001b[36m81014ec8-a24e-4ebd-96b6-651004cd7edb\u001b[0m| \u001b[36msuperduperdb.base.datalayer\u001b[0m:\u001b[36m873 \u001b[0m | \u001b[34m\u001b[1mmodel/svc/0 already exists - doing nothing\u001b[0m\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "8it [00:00, 13628.93it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "([None, SVC(C=100, class_weight='balanced', verbose=True), None],\n",
       " Stack(identifier='my-stack', components=[Model(identifier='all-MiniLM-L6-v2', encoder=Encoder(identifier='numpy.float32[384]', decoder=<Artifact artifact=<superduperdb.ext.numpy.encoder.DecodeArray object at 0x156324190> serializer=dill>, encoder=<Artifact artifact=<superduperdb.ext.numpy.encoder.EncodeArray object at 0x156324e50> serializer=dill>, shape=(384,), load_hybrid=True), output_schema=None, flatten=False, preprocess=None, postprocess=None, collate_fn=None, batch_predict=True, takes_context=False, metrics=(), model_update_kwargs={}, validation_sets=None, predict_X='text', predict_select=<superduperdb.backends.mongodb.query.MongoCompoundSelect[\n",
       "     \u001b[92m\u001b[1mtransfer.find({'_id': \"{'$in': '[657983a788bef09ab851d97a, 657983a788bef09ab851d97b, 657983a788bef09ab851d97c, 657983a788bef09ab851d97d, 657983a788bef09ab851d97e, 657983a788bef09ab851d97f, 657983a788bef09ab851d980, 657983a788bef09ab851d981, 657983a788bef09ab851d982, 657983a788bef09ab851d983, 657983a788bef09ab851d984, 657983a788bef09ab851d985, 657983a788bef09ab851d986, 657983a788bef09ab851d987, 657983a788bef09ab851d988, 657983a788bef09ab851d989, 657983a788bef09ab851d98a, 657983a788bef09ab851d98b, 657983a788bef09ab851d98c, 657983a788bef09ab851d98d, 657983a788bef09ab851d98e, 657983a788bef09ab851d98f, 657983a788bef09ab851d990, 657983a788bef09ab851d991, 657983a788bef09ab851d992, 657983a788bef09ab851d993, 657983a788bef09ab851d994, 657983a788bef09ab851d995, 657983a788bef09ab851d996, 657983a788bef09ab851d997, 657983a788bef09ab851d998, 657983a788bef09ab851d999, 657983a788bef09ab851d99a, 657983a788bef09ab851d99b, 657983a788bef09ab851d99c, 657983a788bef09ab851d99d, 657983a788bef09ab851d99e, 657983a788bef09ab851d99f, 657983a788bef09ab851d9a0, 657983a788bef09ab851d9a1, 657983a788bef09ab851d9a2, 657983a788bef09ab851d9a3, 657983a788bef09ab851d9a4, 657983a788bef09ab851d9a5, 657983a788bef09ab851d9a6, 657983a788bef09ab851d9a7, 657983a788bef09ab851d9a8, 657983a788bef09ab851d9a9, 657983a788bef09ab851d9aa, 657983a788bef09ab851d9ab, 657983a788bef09ab851d9ac, 657983a788bef09ab851d9ad, 657983a788bef09ab851d9ae, 657983a788bef09ab851d9af, 657983a788bef09ab851d9b0, 657983a788bef09ab851d9b1, 657983a788bef09ab851d9b2, 657983a788bef09ab851d9b3, 657983a788bef09ab851d9b4, 657983a788bef09ab851d9b5, 657983a788bef09ab851d9b6, 657983a788bef09ab851d9b7, 657983a788bef09ab851d9b8, 657983a788bef09ab851d9b9, 657983a788bef09ab851d9ba, 657983a788bef09ab851d9bb, 657983a788bef09ab851d9bc, 657983a788bef09ab851d9bd, 657983a788bef09ab851d9be, 657983a788bef09ab851d9bf, 657983a788bef09ab851d9c0, 657983a788bef09ab851d9c1, 657983a788bef09ab851d9c2, 657983a788bef09ab851d9c3, 657983a788bef09ab851d9c4, 657983a788bef09ab851d9c5, 657983a788bef09ab851d9c6, 657983a788bef09ab851d9c7, 657983a788bef09ab851d9c8, 657983a788bef09ab851d9c9, 657983a788bef09ab851d9ca, 657983a788bef09ab851d9cb, 657983a788bef09ab851d9cc, 657983a788bef09ab851d9cd, 657983a788bef09ab851d9ce, 657983a788bef09ab851d9cf, 657983a788bef09ab851d9d0, 657983a788bef09ab851d9d1, 657983a788bef09ab851d9d2, 657983a788bef09ab851d9d3, 657983a788bef09ab851d9d4, 657983a788bef09ab851d9d5, 657983a788bef09ab851d9d6, 657983a788bef09ab851d9d7, 657983a788bef09ab851d9d8, 657983a788bef09ab851d9d9, 657983a788bef09ab851d9da, 657983a788bef09ab851d9db, 657983a788bef09ab851d9dc, 657983a788bef09ab851d9dd]'}\"}, {})\u001b[0m}\n",
       " ] object at 0x170b23a50>, predict_max_chunk_size=None, predict_kwargs={'show_progress_bar': True}, object=<Artifact artifact=SentenceTransformer(\n",
       "   (0): Transformer({'max_seq_length': 256, 'do_lower_case': False}) with Transformer model: BertModel \n",
       "   (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})\n",
       "   (2): Normalize()\n",
       " ) serializer=dill>, model_to_device_method=None, metric_values={}, predict_method='encode', serializer='dill', device='cpu', preferred_devices=('cuda', 'mps', 'cpu'), training_configuration=None, train_X=None, train_y=None, train_select=None), Estimator(identifier='svc', encoder=None, output_schema=None, flatten=False, preprocess=None, postprocess=<Artifact artifact=<function <lambda> at 0x16dd3ce00> serializer=dill>, collate_fn=None, batch_predict=False, takes_context=False, metrics=(), model_update_kwargs={}, validation_sets=None, predict_X='_outputs.text.all-MiniLM-L6-v2.0', predict_select=<superduperdb.backends.mongodb.query.MongoCompoundSelect[\n",
       "     \u001b[92m\u001b[1mtransfer.find({'_fold': \"'valid'\", '_id': \"{'$in': '[657983a788bef09ab851d987, 657983a788bef09ab851d991, 657983a788bef09ab851d9a8, 657983a788bef09ab851d9c4, 657983a788bef09ab851d9c5, 657983a788bef09ab851d9cb, 657983a788bef09ab851d9d8, 657983a788bef09ab851d9da]'}\"}, {})\u001b[0m}\n",
       " ] object at 0x16a8e8f90>, predict_max_chunk_size=None, predict_kwargs=None, object=<Artifact artifact=SVC(C=100, class_weight='balanced', verbose=True) serializer=dill>, model_to_device_method=None, metric_values={}, predict_method='predict', serializer='dill', device='cpu', preferred_devices=('cuda', 'mps', 'cpu'), training_configuration=None, train_X='_outputs.text.all-MiniLM-L6-v2.0', train_y='label', train_select=<superduperdb.backends.mongodb.query.MongoCompoundSelect[\n",
       "     \u001b[92m\u001b[1mtransfer.find({'_fold': \"'train'\"}, {})\u001b[0m}\n",
       " ] object at 0x17061c390>)]))"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "db.add(stack)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "67b156c1",
   "metadata": {},
   "source": [
    "## Verification\n",
    "\n",
    "To verify that the process has worked, we can sample a few records to inspect the sanity of the predictions."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "76958a1e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "this was a favorite Christmas Special that I wish that they would release on vhs or dvd , since my 3\n",
      "1\n"
     ]
    }
   ],
   "source": [
    "# Query a random document from the 'collection' SuperDuperDB collection\n",
    "r = next(db.execute(collection.aggregate([{'$match': {'_fold': 'valid'}},{'$sample': {'size': 1}}])))\n",
    "\n",
    "# Print a portion of the 'text' field from the random document\n",
    "print(r['text'][:100])\n",
    "\n",
    "# Print the prediction made by the SVC model stored in '_outputs'\n",
    "print(r['_outputs']['text']['svc']['0'])"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
